import os
import json
import torch
from tqdm import tqdm
import sys
import re
import base64
import cv2
from PIL import Image
from io import BytesIO
import random
import numpy as np
import pandas as pd
import datetime
import warnings
import logging
from transformers import MllamaForConditionalGeneration, AutoProcessor
import argparse

# Silence library warnings and emit timestamped INFO-level log lines.
warnings.filterwarnings("ignore")
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Command-line interface, declared as (flag, add_argument keyword arguments)
# pairs and registered in this order.
_CLI_OPTIONS = (
    ("--start", {"type": int, "default": 0, "help": "Starting sample index (inclusive)"}),
    ("--end", {"type": int, "default": None, "help": "Ending sample index (inclusive). If not provided, process till the last sample"}),
    ("--output_dir", {"type": str, "default": ".", "help": "Directory to save result JSONs"}),
    ("--num_runs", {"type": int, "default": 1, "help": "Number of stochastic inference runs per persona (set to 1 for single-run deterministic behaviour)"}),
    ("--similarity_json", {"type": str, "default": None, "help": "Path to similarity JSON file"}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Similarity data: a JSON object keyed by str(sample index); each entry may
# hold a "similar_images" list that is used to pick few-shot examples.
# Loading is best-effort: on any problem we fall back to an empty mapping (the
# random-example fallback then takes over), but we say so in the log instead
# of failing silently.
if args.similarity_json is not None:
    SIM_JSON_PATH = args.similarity_json
else:
    SIM_JSON_PATH = os.path.join(os.path.dirname(__file__), "similaritywebaes.json")

SIMILARITY_DATA = {}
if os.path.exists(SIM_JSON_PATH):
    try:
        with open(SIM_JSON_PATH, "r", encoding="utf-8") as f:
            SIMILARITY_DATA = json.load(f)
    except Exception as exc:
        logger.warning("Could not load similarity data from %s: %s", SIM_JSON_PATH, exc)
        SIMILARITY_DATA = {}
    else:
        # Downstream code calls SIMILARITY_DATA.get(...); anything but a dict
        # would make every sample fail, so reject it here.
        if not isinstance(SIMILARITY_DATA, dict):
            logger.warning(
                "Similarity data in %s is a %s, expected a JSON object - ignoring it.",
                SIM_JSON_PATH,
                type(SIMILARITY_DATA).__name__,
            )
            SIMILARITY_DATA = {}
else:
    logger.warning("Similarity file %s not found - examples will be chosen at random.", SIM_JSON_PATH)

def get_model_and_processor(model_dir="meta-llama/Llama-3.2-90B-Vision"):
    """Load the Mllama model (4-bit quantized) and its processor.

    Args:
        model_dir: Hugging Face model id or local directory handed to
            ``from_pretrained`` for both the processor and the model.

    Returns:
        A ``(model, processor)`` tuple. The whole model is mapped to device 0.
    """
    # Imported here so the file-level import block does not need to change.
    from transformers import BitsAndBytesConfig

    logger.info("Loading model and processor from %s", model_dir)

    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = "0"
        logger.info("Environment variable LOCAL_RANK was not set – defaulting to 0 (single-GPU mode).")

    processor = AutoProcessor.from_pretrained(model_dir)
    model = MllamaForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map={"": 0},  # explicit single-GPU mapping to bypass tensor-parallel auto-init
        # The bare ``load_in_4bit=True`` keyword is deprecated; the supported
        # way to request 4-bit loading is an explicit quantization config.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    return model, processor

def _data_url_to_pil(data_url):
    try:
        header, encoded = data_url.split(",", 1)
        img_bytes = base64.b64decode(encoded)
        return Image.open(BytesIO(img_bytes)).convert("RGB")
    except Exception:
        return None

def inference_batch(image_urls, prompts, sys_prompts, model, processor, max_new_tokens=512):
    """Generate one sampled response per (image, prompt, system prompt) triple.

    Args:
        image_urls: base64 ``data:`` URLs, one per sample. An entry that cannot
            be decoded is replaced by a blank white 256x256 image.
        prompts: user prompt text, one per sample.
        sys_prompts: system/persona prompt, one per sample. It is sent as the
            system turn and is also prepended to the user text.
        model: generation model exposing ``generate`` and ``device``.
        processor: matching processor (tokenizer + image preprocessing).
        max_new_tokens: generation length cap.

    Returns:
        List of decoded completions (prompt tokens removed), one per sample.
        Sampling is always on (temperature 0.85, top-p 0.9), so repeated calls
        with the same inputs can return different text.
    """
    messages_batch = []
    images_per_sample = []
    for image_url, prompt, sys_prompt in zip(image_urls, prompts, sys_prompts):
        messages_batch.append([
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": sys_prompt + prompt},
            ]}
        ])

        # Built in the same loop as the messages so texts and images always
        # line up, even if the input lists differ in length.
        pil_img = _data_url_to_pil(image_url)
        if pil_img is None:
            pil_img = Image.new("RGB", (256, 256), "white")
        # One list per sample: a flat list of images is interpreted by the
        # Mllama processor as a single sample holding several images, which
        # breaks as soon as the batch has more than one entry.
        images_per_sample.append([pil_img])

    # Render each conversation to a prompt string ending with the generation header.
    texts = [
        processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        for messages in messages_batch
    ]

    inputs = processor(
        text=texts,
        images=images_per_sample,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; drop the (padded) prompt part.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_texts

def frame_to_data_url(frame_bgr):
    """Encode an OpenCV BGR frame as a 256x256 JPEG ``data:`` URL.

    Returns None for a missing or empty frame, and also when any conversion
    step raises (the error is printed).
    """
    try:
        # Nothing to encode for a missing or zero-sized frame.
        if frame_bgr is None or frame_bgr.size == 0:
            return None

        # Swap BGR -> RGB for PIL, then shrink to the fixed 256x256 size.
        rgb_pixels = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        thumbnail = Image.fromarray(rgb_pixels).resize((256, 256), Image.LANCZOS)

        # Serialize to JPEG in memory and base64-encode the resulting bytes.
        jpeg_buffer = BytesIO()
        thumbnail.save(jpeg_buffer, format="JPEG")
        payload = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')

        return f"data:image/jpeg;base64,{payload}"
    except Exception as e:
        print(f"Error in frame_to_data_url: {e}")
        return None

# Persona system prompts, keyed "<age range>_<gender>" (five age bands x two
# genders). The driver loop at the bottom of this file iterates over every
# entry and passes the text to `inference_batch` as the system prompt.
# Every prompt asks for the same two-line reply ("Answer: ..." / "Reason: ...");
# the numeric part is parsed by `extract_score_from_response`.
# NOTE(review): each prompt says "You are given 5 example website screenshots",
# but the driver loop only passes the target image to `inference_batch` — the
# example images are never sent to the model. Confirm this is intended.
persona_prompts = {
    "18-24_female": """You are a woman aged 18–24. You're fluent in digital aesthetics, raised on platforms like TikTok and Instagram. You notice instantly if something has a vibe—bold colors, expressive fonts, emotional tone, or modern, fun design. Websites that are cluttered, generic, or try-hard are less likely to appeal to you.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on its visual design, layout, color scheme, and content.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "18-24_male": """You are a man aged 18–24. You're used to fast-scroll content and visual punch—memes, Twitch, TikTok, YouTube. You like websites that grab attention fast: bold layouts, smart design, or a bit of edge. If a website feels outdated, cluttered, or boring, it loses your interest quickly.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on visuals, usability, and vibe.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "25-34_female": """You are a woman aged 25–34. You appreciate modern, polished websites that feel aligned with your lifestyle—whether it's wellness, creativity, relationships, or career. You like clean layouts, elegant color palettes, and visuals that are both pretty and purposeful.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on design, clarity, aesthetics, and content.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "25-34_male": """You are a man aged 25–34. You value strong, clear, and modern visuals. You're likely to appreciate websites that are bold but not messy—clean grids, high contrast, sharp fonts, and relevant content (fitness, tech, ambition, money).

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on its layout, visual punch, and message.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "35-44_female": """You are a woman aged 35–44. You're drawn to websites that are intentional, emotionally intelligent, and visually clean. Family, meaning, and beauty in simplicity appeal to you more than trend-driven clutter.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on its design, clarity, and emotional tone.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "35-44_male": """You are a man aged 35–44. You like websites that are grounded, practical, and cleanly designed. Strong layouts, good use of space, and purpose-driven content grab your attention more than visual noise.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on structure, relevance, and visual balance.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "45-54_female": """You are a woman aged 45–54. You like websites that are calm, clear, and visually composed. Design that feels warm, thoughtful, and emotionally grounded appeals more than flashy visuals or trendy noise.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on clarity, emotional tone, and visual presentation.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "45-54_male": """You are a man aged 45–54. You prefer websites that are easy to navigate, focused, and visually grounded. You're drawn to sites that reflect purpose and clarity over trend or flash.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on usability, structure, and message.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "55+_female": """You are a woman aged 55 or older. You appreciate websites that feel meaningful, visually calm, and easy to understand. Gentle color palettes, clear fonts, and emotionally warm content make a big difference.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on design simplicity and emotional tone.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]""",

    "55+_male": """You are a man aged 55 or older. You value websites that are straightforward, honest, and easy to engage with. Flashy or cluttered pages can feel frustrating, while clear structure and meaningful content feel worthwhile.

You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). You're now shown a new website screenshot. Your task is to judge how much you **like** this website based on clarity, usefulness, and visual comfort.

Return your response in this exact format:
Answer: [0–10] ← You must include this numerical score.
Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]"""
}

# Helper that gathers image/video entries from chat-template messages
def process_vision_info(messages_batch):
    """
    Collect, per sample, the image and video payloads referenced in chat messages.

    Args:
        messages_batch: List[List[dict]] – one list of chat turns per sample.
            Only turns whose ``content`` is a list are inspected, and within
            it only dict entries; an entry with ``type == "image"``
            contributes its ``image`` value, one with ``type == "video"`` its
            ``video`` value (None when that key is absent).

    Returns:
        image_inputs: one list of image payloads per sample (possibly empty),
            in the order they appear across the sample's turns.
        video_inputs: one list of video payloads per sample (possibly empty),
            in the same order convention.
    """

    def _content_parts(sample):
        # Yield every dict entry found in list-typed turn contents, in order.
        for turn in sample:
            parts = turn.get("content", [])
            if isinstance(parts, list):
                for part in parts:
                    if isinstance(part, dict):
                        yield part

    image_inputs, video_inputs = [], []

    for sample in messages_batch:
        parts = list(_content_parts(sample))
        image_inputs.append([p.get("image") for p in parts if p.get("type") == "image"])
        video_inputs.append([p.get("video") for p in parts if p.get("type") == "video"])

    return image_inputs, video_inputs

def batch_verbalize(batch_data, batch_size=16):
    """Run sampled generation over ``batch_data`` in mini-batches.

    Uses the module-level ``model`` and ``processor`` objects.

    Args:
        batch_data: sequence of ``(prompt, sys_prompt, img_tuples)`` items.
            ``img_tuples`` is a list of tuples whose first element is a base64
            ``data:`` URL; only the LAST tuple (the target image) is used. A
            missing or undecodable image is replaced by a blank white 256x256
            image.
        batch_size: number of samples sent to the model per forward pass.

    Returns:
        One string per input item, in order: the decoded completion, or
        ``"Error: <message>"`` for every item of a mini-batch that failed.
    """
    all_results = []

    for idx in tqdm(range(0, len(batch_data), batch_size), desc="Model Inference"):
        current_batch = batch_data[idx : idx + batch_size]

        messages_for_batch = []
        target_urls = []
        for prompt, sys_prompt, img_tuples in current_batch:
            # last image in the tuple list is the target
            data_url = img_tuples[-1][0] if img_tuples else None
            target_urls.append(data_url)
            messages_for_batch.append([
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": [
                    {"type": "image", "image": data_url},
                    {"type": "text", "text": "Strictly answer ONLY from the perspective of the given persona. Your response must be specific to the persona described. " + sys_prompt + prompt},
                ]}
            ])

        try:
            texts = [
                processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                for messages in messages_for_batch
            ]

            # One image list per sample: a flat list of images is interpreted
            # by the Mllama processor as a single sample holding several
            # images, which fails for any mini-batch larger than one.
            images_per_sample = []
            for data_url in target_urls:
                pil = _data_url_to_pil(data_url) if data_url else None
                if pil is None:
                    pil = Image.new("RGB", (256, 256), "white")
                images_per_sample.append([pil])

            inputs = processor(
                text=texts,
                images=images_per_sample,
                padding=True,
                return_tensors="pt",
            ).to(model.device)

            with torch.no_grad():
                gen_ids = model.generate(
                    **inputs,
                    max_new_tokens=1200,
                    do_sample=True,
                    temperature=0.85,
                    top_p=0.9,
                    pad_token_id=processor.tokenizer.eos_token_id,
                )

            # generate() returns prompt + completion; drop the (padded) prompt part.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, gen_ids)
            ]
            outputs = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            all_results.extend(outputs)

        except Exception as e:
            # Best-effort: keep the output aligned with the input by emitting
            # one error marker per item of the failed mini-batch.
            print(f"Error during batch inference: {e}")
            all_results.extend([f"Error: {e}"] * len(current_batch))

    return all_results

import pandas as pd
import re

def safe_load_image(image_path):
    """Load an image from disk, trying OpenCV first and PIL second.

    The PIL result is converted from RGB to BGR before being returned.
    Returns None when both loaders fail; loader errors are printed.
    """
    # Primary loader: OpenCV.
    try:
        loaded = cv2.imread(image_path)
        if not (loaded is None or loaded.size <= 0):
            return loaded
    except Exception as e:
        print(f"cv2.imread failed for {image_path}: {e}")

    # Secondary loader: PIL, then reorder channels RGB -> BGR.
    try:
        rgb_pixels = np.array(Image.open(image_path).convert('RGB'))
        return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"PIL fallback failed for {image_path}: {e}")

    return None

# Read the test split; without it there is nothing to score, so exit with status 1.
test_filename = "website-aesthetics-datasets/rating-based-dataset/preprocess/test_list.csv"
try:
    df = pd.read_csv(test_filename)
    print(f"Loaded dataset with {len(df)} samples")
except Exception as load_error:
    print(f"Error loading dataset: {load_error}")
    raise SystemExit(1)

# Decoration a model may put between a label and its number: spaces, "[", "(", "*".
_SCORE_LEAD = r'[\s\[\(\*]*'
_SCORE_NUMBER = r'(\d+(?:\.\d+)?)'
_ANSWER_RE = re.compile(r'Answer\s*:' + _SCORE_LEAD + _SCORE_NUMBER, re.IGNORECASE)
_LABELLED_RE = re.compile(r'(?:score|rating)\s*:' + _SCORE_LEAD + _SCORE_NUMBER, re.IGNORECASE)
_OUT_OF_TEN_RE = re.compile(r'\b' + _SCORE_NUMBER + r'\s*(?:/|out\s+of)\s*10\b', re.IGNORECASE)
_ANY_NUMBER_RE = re.compile(r'\b(\d+(?:\.\d+)?)\b')


def _clamp_score(value):
    """Clamp ``value`` to the 0-10 rating scale."""
    return max(0.0, min(10.0, value))


def extract_score_from_response(resp):
    """Extract a 0-10 score from a model response.

    Strategies are tried in order; the first that yields a number wins:

    1. ``Answer: X`` – also accepts ``Answer: [X]``, ``Answer: (X)`` and
       markdown such as ``**Answer:** X``, because the prompt template itself
       shows the score in brackets. The value is clamped to 0-10.
    2. ``score: X`` / ``rating: X`` (same decorations allowed), clamped.
    3. ``X/10`` or ``X out of 10`` with X already inside 0-10.
    4. Fallback: the last standalone number in the text that lies in 0-10.

    Args:
        resp: the raw response text.

    Returns:
        The score as a float, or None if ``resp`` is not a string or holds no
        usable number.
    """
    if not isinstance(resp, str):
        return None

    # Strategies 1 and 2: explicitly labelled score.
    for labelled in (_ANSWER_RE, _LABELLED_RE):
        match = labelled.search(resp)
        if match:
            return _clamp_score(float(match.group(1)))

    # Strategy 3: "8/10" / "8 out of 10". Without this the fallback below
    # would return the trailing 10 instead of the 8.
    for match in _OUT_OF_TEN_RE.finditer(resp):
        value = float(match.group(1))
        if 0 <= value <= 10:
            return value

    # Strategy 4: last in-range number anywhere in the text.
    in_range = [
        value
        for value in (float(n) for n in _ANY_NUMBER_RE.findall(resp))
        if 0 <= value <= 10
    ]
    return in_range[-1] if in_range else None

def _dataset_image_path(fname, image_root):
    """Map a dataset image name to its path under ``image_root`` (drops the '_resized' marker)."""
    return image_root + fname.replace('_resized', '').lstrip('/')


def _image_file_to_data_url(img_path):
    """Return a data URL for the image at ``img_path``, or None if it is missing or unreadable."""
    if not os.path.exists(img_path):
        return None
    img = safe_load_image(img_path)
    if img is None:
        return None
    return frame_to_data_url(img)


def prepare_sample_data(
    df,
    sample_indices,
    image_root='website-aesthetics-datasets/rating-based-dataset/images/',
    max_examples=5,
):
    """Load target images and few-shot examples for the requested rows.

    For every index the target image is loaded and encoded as a data URL.
    Up to ``max_examples`` example images are then gathered, first from the
    similarity data for that index and, if that is not enough, from randomly
    chosen other rows of ``df``.

    Args:
        df: DataFrame with at least an ``image`` column (file name) and a
            ``mean_score`` column (used for randomly drawn examples).
        sample_indices: positional row indices to prepare.
        image_root: directory prefix under which image files are looked up.
        max_examples: maximum number of example images per sample.

    Returns:
        ``(batch_samples, skipped_samples)`` where ``batch_samples`` is a list
        of dicts with keys ``index``, ``value`` (the row as a dict),
        ``prompt``, ``example_images`` (list of ``(data_url, score)`` tuples
        with the target image last, score None) and ``valid_examples``;
        ``skipped_samples`` is a list of ``(index, reason)`` tuples.
    """
    batch_samples = []
    skipped_samples = []

    print("Preparing sample data...")
    for i in tqdm(sample_indices, desc="Loading samples"):
        try:
            d = df.iloc[i]
            value = d.to_dict()
            image_path = _dataset_image_path(d['image'], image_root)

            # The three failure modes are kept apart so the skip reason is precise.
            if not os.path.exists(image_path):
                skipped_samples.append((i, f"Image file does not exist: {image_path}"))
                continue

            image = safe_load_image(image_path)
            if image is None:
                skipped_samples.append((i, f"Could not load image: {image_path}"))
                continue

            image_url = frame_to_data_url(image)
            if image_url is None:
                skipped_samples.append((i, f"Could not process image: {image_path}"))
                continue

            example_lines = []
            example_images = []

            # Similarity-based retrieval
            similar_list = SIMILARITY_DATA.get(str(i), {}).get("similar_images", [])
            for sim in similar_list:
                if len(example_images) >= max_examples:
                    break
                try:
                    score = sim.get("mean_score", None)
                    line = f"Score: {score:.1f}" if score is not None else "Score: N/A"
                    img_url = _image_file_to_data_url(_dataset_image_path(sim["image"], image_root))
                except Exception:
                    # Best-effort: a malformed similarity entry is simply not used.
                    continue
                if img_url is None:
                    continue
                example_lines.append(line)
                example_images.append((img_url, score))

            # Random fallback if there are still too few examples
            if len(example_images) < max_examples:
                # Normalize a negative position so the target row is always excluded.
                target_pos = i if i >= 0 else len(df) + i
                other_indices = [j for j in range(len(df)) if j != target_pos]
                random.shuffle(other_indices)
                for idx in other_indices:
                    if len(example_images) >= max_examples:
                        break
                    try:
                        row = df.iloc[idx]
                        score = row['mean_score']
                        line = f"Score: {score:.1f}"
                        img_url = _image_file_to_data_url(_dataset_image_path(row['image'], image_root))
                    except Exception:
                        # Best-effort: an unusable row is simply not used as an example.
                        continue
                    if img_url is None:
                        continue
                    example_lines.append(line)
                    example_images.append((img_url, score))

            valid_examples = len(example_images)

            # Add the current image as the last one
            example_images.append((image_url, None))
            examples_text = "\n".join(example_lines)

            if valid_examples > 0:
                prompt = (
                    f"Given the images below, the first {valid_examples} are example website screenshots with their likeability scores (on a 0-10 scale, see the list below). The last image is the one you should score. \n"
                    "\n"
                    "Carefully analyze the last website screenshot and provide a score between 0 to 10 based on how much people would like the website's visual design, layout, colors, typography, and overall aesthetic appeal.\n"
                    "\n"
                    f"Here are {valid_examples} example likeability scores (in order):\n"
                    f"{examples_text}\n"
                    "\n"
                    "Please evaluate the final website screenshot and provide your assessment."
                )
            else:
                # Without this branch ``prompt`` would be unbound for the first
                # sample and silently reuse the previous sample's text (with
                # the wrong example scores) for every later one.
                prompt = (
                    "Carefully analyze the website screenshot and provide a score between 0 to 10 based on how much people would like the website's visual design, layout, colors, typography, and overall aesthetic appeal.\n"
                    "\n"
                    "Please evaluate the website screenshot and provide your assessment."
                )

            batch_samples.append({
                'index': i,
                'value': value,
                'prompt': prompt,
                'example_images': example_images,
                'valid_examples': valid_examples
            })

        except Exception as e:
            skipped_samples.append((i, f"Unexpected error: {str(e)}"))
            continue

    return batch_samples, skipped_samples

# Number of sampled generations per persona and sample.
NUM_RUNS = args.num_runs

# Load model and processor
logger.info("Loading model and processor...")
model, processor = get_model_and_processor("meta-llama/Llama-3.2-11B-Vision")
logger.info("Model and processor loaded successfully!")

response_dict = []
processed_count = 0
skipped_count = 0
error_count = 0

print(f"Starting processing of {len(df)} samples with batch processing and {len(persona_prompts)} personas...")

# Prepare list of sample indices according to CLI slice (--end is inclusive)
if args.end is not None:
    all_indices = list(range(args.start, min(args.end + 1, len(df))))
else:
    all_indices = list(range(args.start, len(df)))

batch_samples, skipped_samples = prepare_sample_data(df, all_indices)
skipped_count = len(skipped_samples)
print(f"Prepared {len(batch_samples)} valid samples, skipped {len(skipped_samples)}")

# Create the output directory up front: otherwise every incremental save below
# fails inside the per-sample try block and no results are ever written.
if args.output_dir:
    os.makedirs(args.output_dir, exist_ok=True)
output_filename = os.path.join(
    args.output_dir,
    f'results_llama_persona_static_web_aes_ten_slice_{args.start}_{args.end if args.end is not None else "end"}.json'
)

for sample_idx, sample in enumerate(tqdm(batch_samples, desc="Samples")):
    try:
        persona_results = {}

        # The target image is the last entry of example_images.
        image_url = sample['example_images'][-1][0]
        prompt_text = sample['prompt']

        for persona_name, persona_prompt in persona_prompts.items():
            persona_results[persona_name] = {"all_responses": [], "all_predictions": []}

            for run_idx in range(NUM_RUNS):
                # One single-sample call per run; the persona text is the system prompt.
                resp = inference_batch(
                    [image_url],
                    [prompt_text],
                    [persona_prompt],
                    model,
                    processor,
                )[0]

                persona_results[persona_name]["all_responses"].append(resp)

                if not resp.startswith("Error"):
                    score_val = extract_score_from_response(resp)
                    if score_val is not None:
                        persona_results[persona_name]["all_predictions"].append(score_val)

        # Compute persona means (None when a persona produced no parsable score)
        all_persona_means = []
        for pdata in persona_results.values():
            if pdata['all_predictions']:
                mean_pred = np.mean(pdata['all_predictions'])
            else:
                mean_pred = None
            pdata['mean_prediction'] = mean_pred
            pdata['num_valid_predictions'] = len(pdata['all_predictions'])
            if mean_pred is not None:
                all_persona_means.append(mean_pred)

        # Overall prediction = mean of the per-persona means that exist.
        overall_mean = np.mean(all_persona_means) if all_persona_means else None

        # Attach results to sample value
        sample['value'].update({
            "persona_responses": persona_results,
            "overall_mean_prediction": overall_mean,
            "num_personas": len(persona_prompts),
            "valid_persona_predictions": len(all_persona_means)
        })
        response_dict.append(sample['value'])
        processed_count = len(response_dict)

        # Incremental save: the full result list is rewritten after every sample
        with open(output_filename, 'w') as f_out:
            json.dump(response_dict, f_out, indent=4)
        print(f"💾 Saved progress after sample {sample_idx + 1}/{len(batch_samples)} -> {output_filename}")

    except Exception as e:
        error_count += 1
        print(f"[ERROR] Failed processing sample idx {sample_idx}: {e}")
        continue

print("All requested samples processed.")
print(f"Total processed: {len(response_dict)} | Skipped: {len(skipped_samples)} | Failed: {error_count}")
